#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author ：hhx
@Date ：2022/5/21 21:39 
@Description ：AEE 训练
"""
import numpy as np
import os
from utils import *
import torch
from torch import nn, optim
from torch.utils import data
from models import AE, AE_withLinear, AEE_Convd
from tqdm import tqdm
import torch.nn.functional as F
from PIL import Image

# import os
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
device = 'cpu'
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type(torch.DoubleTensor)

EPS = 1e-15

if __name__ == '__main__':
    batch = 8
    # Raw string so the backslash in the Windows path can never be read as an escape.
    datasetpath = r'G:\哨兵2号数据'
    trainSet = CarTiffDateSet(datasetpath)
    train_loader = torch.utils.data.DataLoader(dataset=trainSet,
                                               batch_size=batch,
                                               shuffle=True)
    # Build the three sub-networks of the adversarial autoencoder:
    # Q = encoder, P = decoder, D_gauss = discriminator over the latent code.
    Q = AEE_Convd.Q_net().to(device)
    P = AEE_Convd.P_net().to(device)
    D_gauss = AEE_Convd.D_net_gauss().to(device)

    # Optimizers
    # encode/decode optimizers
    optim_P = torch.optim.Adam(P.parameters())
    optim_Q_enc = torch.optim.Adam(Q.parameters())
    # regularizing optimizers
    optim_Q_gen = torch.optim.Adam(Q.parameters())
    optim_D = torch.optim.Adam(D_gauss.parameters())
    cost = nn.MSELoss().to(device)

    # Training loop
    for epoch in range(20):

        # Enable dropout / batch-norm training behaviour in all sub-networks.
        Q.train()
        P.train()
        D_gauss.train()

        D_loss_sum = torch.zeros(1).to(device)      # accumulated discriminator loss
        G_loss_sum = torch.zeros(1).to(device)      # accumulated generator loss
        recon_loss_sum = torch.zeros(1).to(device)  # accumulated reconstruction loss

        num_batches = 0  # guards the post-loop average against an empty loader
        # `batch_data`, not `data`: avoids shadowing the `torch.utils.data` import.
        for step, batch_data in enumerate(train_loader):
            images, labels = batch_data
            num_batches = step + 1

            ########################
            # Reconstruction phase
            ########################
            z_sample = Q(images)    # encode to z
            X_sample = P(z_sample)  # decode to X reconstruction

            # MSE reconstruction loss. (The former `+ EPS` on both operands was
            # a leftover from the commented-out BCE variant; it cancels in MSE.)
            recon_loss = cost(X_sample, images)
            # recon_loss = F.binary_cross_entropy(X_sample + EPS, images + EPS)

            # Backpropagate the reconstruction loss and update encoder + decoder.
            recon_loss.backward()
            optim_P.step()
            optim_Q_enc.step()

            # Clear all gradients so this phase does not leak into the next one.
            P.zero_grad()
            Q.zero_grad()
            D_gauss.zero_grad()

            #######################
            # Regularization phase
            #######################
            # --- Discriminator ---
            Q.eval()
            # "Real" latent codes sampled from a scaled standard Gaussian prior.
            z_real_gauss = (torch.randn(images.size()[0], 64) * 5.).to(device)

            # "Fake" latent codes come from the encoder. Detach so the
            # discriminator update does not build/backprop a graph through Q —
            # those encoder gradients were discarded (zeroed) anyway.
            z_fake_gauss = Q(images).detach()

            # Score both kinds of latent code; one scalar per sample.
            D_fake_gauss = D_gauss(z_fake_gauss)
            D_real_gauss = D_gauss(z_real_gauss)

            # Discriminator objective: prior samples -> 1, encoder samples -> 0.
            D_loss = -torch.mean(torch.log(D_real_gauss + EPS) + torch.log(1 - D_fake_gauss + EPS))

            # Backpropagate and update the discriminator only.
            D_loss.backward()
            optim_D.step()

            # Clear gradients again before the generator step.
            P.zero_grad()
            Q.zero_grad()
            D_gauss.zero_grad()

            # --- Generator (the encoder acting as generator) ---
            # Re-enable dropout in the encoder.
            Q.train()

            # Latent codes produced by the encoder; not Gaussian yet.
            z_fake_gauss = Q(images)
            D_fake_gauss = D_gauss(z_fake_gauss)

            # Generator objective: make the discriminator label encoder codes
            # as 1, pushing the latent distribution toward the Gaussian prior.
            # Unlike a plain GAN, "fooling" here means matching the prior, not
            # matching real data.
            G_loss = -torch.mean(torch.log(D_fake_gauss + EPS))

            # Backpropagate and update the encoder via its regularizing optimizer.
            G_loss.backward()
            optim_Q_gen.step()

            # Clear gradients for the next batch.
            P.zero_grad()
            Q.zero_grad()
            D_gauss.zero_grad()

            # Accumulate detached scalars for per-epoch reporting.
            D_loss_sum += D_loss.detach()
            G_loss_sum += G_loss.detach()
            recon_loss_sum += recon_loss.detach()

        # Report per-epoch averages (skip if the loader yielded nothing;
        # the old code raised NameError on `step` in that case).
        if num_batches:
            print(f"epoch:{epoch};"
                  f" D_loss_gauss:{D_loss_sum.item() / num_batches:.3f};"
                  f" G_loss:{G_loss_sum.item() / num_batches:.3f};"
                  f" recon_loss: {recon_loss_sum.item() / num_batches:.3f} ")

    # Ensure the output directory exists before saving the encoder weights.
    os.makedirs('SavedModels', exist_ok=True)
    torch.save(Q.state_dict(), 'SavedModels/QNet.pkl')
